-- Copyright 2016 TensorFlow authors.
--
-- Licensed under the Apache License, Version 2.0 (the "License");
-- you may not use this file except in compliance with the License.
-- You may obtain a copy of the License at
--
--     http://www.apache.org/licenses/LICENSE-2.0
--
-- Unless required by applicable law or agreed to in writing, software
-- distributed under the License is distributed on an "AS IS" BASIS,
-- WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-- See the License for the specific language governing permissions and
-- limitations under the License.

{-# LANGUAGE ConstraintKinds #-}
{-# LANGUAGE CPP #-}
{-# LANGUAGE DataKinds #-}
{-# LANGUAGE FlexibleContexts #-}
{-# LANGUAGE FlexibleInstances #-}
{-# LANGUAGE GADTs #-}
{-# LANGUAGE GeneralizedNewtypeDeriving #-}
{-# LANGUAGE MonoLocalBinds #-}
{-# LANGUAGE MultiParamTypeClasses #-}
{-# LANGUAGE OverloadedStrings #-}
{-# LANGUAGE RankNTypes #-}
{-# LANGUAGE ScopedTypeVariables #-}
{-# LANGUAGE TypeFamilies #-}
{-# LANGUAGE TypeOperators #-}
-- We use UndecidableInstances for type families with recursive definitions
-- like "\\".  Those instances will terminate since each equation unwraps one
-- cons cell of a type-level list.
{-# LANGUAGE UndecidableInstances #-}

module TensorFlow.Types
    ( TensorType(..)
    , TensorData(..)
    , TensorDataType(..)
    , Scalar(..)
    , Shape(..)
    , protoShape
    , Attribute(..)
    , DataType(..)
    , ResourceHandle
    , Variant
    -- * Lists
    , ListOf(..)
    , List
    , (/:/)
    , TensorTypeProxy(..)
    , TensorTypes(..)
    , TensorTypeList
    , fromTensorTypeList
    , fromTensorTypes
    -- * Type constraints
    , OneOf
    , type (/=)
    , OneOfs
    -- ** Implementation of constraints
    , TypeError
    , ExcludedCase
    , NoneOf
    , type (\\)
    , Delete
    , AllTensorTypes
    ) where

import Data.Bits (shiftL, (.|.))
import Data.ProtoLens.Message(defMessage)
import Data.Functor.Identity (Identity(..))
import Data.Complex (Complex)
import Data.Int (Int8, Int16, Int32, Int64)
import Data.Maybe (fromMaybe)
import Data.ProtoLens.TextFormat (showMessageShort)
import Data.Proxy (Proxy(..))
import Data.String (IsString)
import Data.Word (Word8, Word16, Word32, Word64)
import Foreign.Storable (Storable)
import GHC.Exts (Constraint, IsList(..))
import Lens.Family2 (Lens', view, (&), (.~), (^..), under)
import Lens.Family2.Unchecked (adapter)
import Text.Printf (printf)
import Data.ByteString (ByteString)
import qualified Data.ByteString as B
import Data.ByteString.Builder (Builder)
import qualified Data.ByteString.Builder as Builder
import qualified Data.ByteString.Lazy as L
import qualified Data.Vector as V
import qualified Data.Vector.Storable as S
import Data.Vector.Split (chunksOf)
import Proto.Tensorflow.Core.Framework.AttrValue
    ( AttrValue
    , AttrValue'ListValue
    )
import Proto.Tensorflow.Core.Framework.AttrValue_Fields
    ( b
    , f
    , i
    , s
    , list
    , type'
    , shape
    , tensor
    )

import Proto.Tensorflow.Core.Framework.ResourceHandle
    (ResourceHandleProto)
import Proto.Tensorflow.Core.Framework.Tensor as Tensor
    (TensorProto)
import Proto.Tensorflow.Core.Framework.Tensor_Fields as Tensor
    ( boolVal
    , doubleVal
    , floatVal
    , intVal
    , int64Val
    , resourceHandleVal
    , stringVal
    , uint32Val
    , uint64Val
    )

import Proto.Tensorflow.Core.Framework.TensorShape
    (TensorShapeProto)
import Proto.Tensorflow.Core.Framework.TensorShape_Fields
    ( dim
    , size
    , unknownRank
    )
import Proto.Tensorflow.Core.Framework.Types (DataType(..))

import qualified TensorFlow.Internal.Raw as Raw
import qualified TensorFlow.Internal.FFI as FFI

-- | Element type of @DT_RESOURCE@ tensors: the generated
-- 'ResourceHandleProto' message.
type ResourceHandle = ResourceHandleProto

-- | Dynamic type.
-- TensorFlow variants aren't supported yet. This type acts as a placeholder to
-- simplify op generation. It has no constructors, so it is only meaningful at
-- the type level.
data Variant

-- | The class of scalar types supported by tensorflow.
--
-- Each instance ties a Haskell element type to its TensorFlow 'DataType'
-- and to the repeated field of 'TensorProto' that stores its values.
class TensorType a where
    -- | The 'DataType' for tensors of this element type.  The argument is
    -- only a type witness: callers in this module pass @undefined :: a@ and
    -- the instances here never inspect it.
    tensorType :: a -> DataType
    -- | The reference-typed counterpart of 'tensorType' (the @*_REF@
    -- 'DataType').  The argument is likewise only a type witness.
    tensorRefType :: a -> DataType
    -- | Lens onto the 'TensorProto' field holding values of this type.
    -- Some instances (the complex types, 'Variant') leave this as an
    -- 'error' stub.
    tensorVal :: Lens' TensorProto [a]

-- Floating-point and signed 32/64-bit integers each have a dedicated
-- repeated field in 'TensorProto', so no conversion is needed.
instance TensorType Float where
    tensorVal = floatVal
    tensorType = const DT_FLOAT
    tensorRefType = const DT_FLOAT_REF

instance TensorType Double where
    tensorVal = doubleVal
    tensorType = const DT_DOUBLE
    tensorRefType = const DT_DOUBLE_REF

instance TensorType Int32 where
    tensorVal = intVal
    tensorType = const DT_INT32
    tensorRefType = const DT_INT32_REF

instance TensorType Int64 where
    tensorVal = int64Val
    tensorType = const DT_INT64
    tensorRefType = const DT_INT64_REF

-- | Re-views a list of 'Int32' (such as the shared 'intVal' field) as a list
-- of another 'Integral' type.  Both directions go through 'fromIntegral', so
-- out-of-range values are converted without any range check.
integral :: Integral a => Lens' [Int32] [a]
integral = under (adapter (map fromIntegral) (map fromIntegral))

-- Narrow integer types share the 'Int32' 'intVal' field and are converted
-- element-wise through 'integral'.
instance TensorType Word8 where
    tensorVal = intVal . integral
    tensorType = const DT_UINT8
    tensorRefType = const DT_UINT8_REF

instance TensorType Word16 where
    tensorVal = intVal . integral
    tensorType = const DT_UINT16
    tensorRefType = const DT_UINT16_REF

-- Unsigned 32/64-bit integers have their own repeated fields.
instance TensorType Word32 where
    tensorVal = uint32Val
    tensorType = const DT_UINT32
    tensorRefType = const DT_UINT32_REF

instance TensorType Word64 where
    tensorVal = uint64Val
    tensorType = const DT_UINT64
    tensorRefType = const DT_UINT64_REF

instance TensorType Int16 where
    tensorVal = intVal . integral
    tensorType = const DT_INT16
    tensorRefType = const DT_INT16_REF

instance TensorType Int8 where
    tensorVal = intVal . integral
    tensorType = const DT_INT8
    tensorRefType = const DT_INT8_REF

instance TensorType ByteString where
    tensorVal = stringVal
    tensorType = const DT_STRING
    tensorRefType = const DT_STRING_REF

instance TensorType Bool where
    tensorVal = boolVal
    tensorType = const DT_BOOL
    tensorRefType = const DT_BOOL_REF

-- Complex element types are only partially supported: their 'DataType's are
-- known, but 'tensorVal' is an 'error' stub.
instance TensorType (Complex Float) where
    tensorType _ = DT_COMPLEX64
    -- The ref type must be the @_REF@ variant, as in every other instance.
    tensorRefType _ = DT_COMPLEX64_REF
    tensorVal = error "TODO (Complex Float)"

instance TensorType (Complex Double) where
    tensorType _ = DT_COMPLEX128
    -- The ref type must be the @_REF@ variant, as in every other instance.
    tensorRefType _ = DT_COMPLEX128_REF
    tensorVal = error "TODO (Complex Double)"

instance TensorType ResourceHandle where
    tensorVal = resourceHandleVal
    tensorType = const DT_RESOURCE
    tensorRefType = const DT_RESOURCE_REF

-- 'Variant' values cannot be stored in a 'TensorProto' yet.
instance TensorType Variant where
    tensorVal = error "TODO Variant"
    tensorType = const DT_VARIANT
    tensorRefType = const DT_VARIANT_REF

-- | Tensor data with the correct memory layout for tensorflow.
--
-- A thin wrapper over 'FFI.TensorData'.  The parameter @a@ is phantom: it
-- records the element type at the Haskell level, but the wrapper itself
-- does not check that the stored bytes hold @a@ values.
newtype TensorData a = TensorData { unTensorData :: FFI.TensorData }

-- | Types that can be converted to and from 'TensorData'.
--
-- @s@ is the container (for example 'S.Vector', 'V.Vector' or 'Scalar') and
-- @a@ is the element type.
--
-- 'S.Vector' is the most efficient to encode/decode for most element types.
class TensorType a => TensorDataType s a where
    -- | Decode the bytes of a 'TensorData' into an 's'.
    decodeTensorData :: TensorData a -> s a
    -- | Encode an 's' into a 'TensorData'.
    --
    -- The values should be in row major order, e.g.,
    --
    --   element 0:   index (0, ..., 0)
    --   element 1:   index (0, ..., 1)
    --   ...
    encodeTensorData :: Shape -> s a -> TensorData a

-- Every element type except ByteString and Bool is laid out as a plain
-- array, so a Vector.Storable can be decoded or encoded just by
-- reinterpreting (casting) the underlying bytes.

-- TODO(fmayle): Assert that the data type matches the return type.
-- | Reinterpret the raw bytes of a 'TensorData' as a storable vector of @a@.
-- The tensor's recorded 'DataType' is not compared against @a@.
simpleDecode :: Storable a => TensorData a -> S.Vector a
simpleDecode td = S.unsafeCast (FFI.tensorDataBytes (unTensorData td))

-- | Wrap a storable vector as 'TensorData' of the given shape by
-- reinterpreting its bytes.  Calls 'error' when the number of elements in
-- the vector differs from the product of the dimensions.
simpleEncode :: forall a . (TensorType a, Storable a)
             => Shape -> S.Vector a -> TensorData a
simpleEncode (Shape dims) vec
    | expected /= fromIntegral actual =
        error $ printf
            "simpleEncode: bad vector length for shape %v: expected=%d got=%d"
            (show dims) expected actual
    | otherwise = TensorData (FFI.TensorData dims dataType (S.unsafeCast vec))
  where
    expected = product dims
    actual = S.length vec
    dataType = tensorType (undefined :: a)

-- Storable element types that TensorFlow lays out as plain arrays of the
-- same machine type.  Both directions reuse 'simpleDecode' and
-- 'simpleEncode', which cast bytes and take the 'DataType' from
-- 'tensorType'.  Through the generic 'V.Vector' and 'Scalar' instances, each
-- of these also gets boxed-vector and scalar support.
instance TensorDataType S.Vector Float where
    decodeTensorData = simpleDecode
    encodeTensorData = simpleEncode

instance TensorDataType S.Vector Double where
    decodeTensorData = simpleDecode
    encodeTensorData = simpleEncode

instance TensorDataType S.Vector Int8 where
    decodeTensorData = simpleDecode
    encodeTensorData = simpleEncode

instance TensorDataType S.Vector Int16 where
    decodeTensorData = simpleDecode
    encodeTensorData = simpleEncode

instance TensorDataType S.Vector Int32 where
    decodeTensorData = simpleDecode
    encodeTensorData = simpleEncode

instance TensorDataType S.Vector Int64 where
    decodeTensorData = simpleDecode
    encodeTensorData = simpleEncode

instance TensorDataType S.Vector Word8 where
    decodeTensorData = simpleDecode
    encodeTensorData = simpleEncode

instance TensorDataType S.Vector Word16 where
    decodeTensorData = simpleDecode
    encodeTensorData = simpleEncode

-- Word32/Word64 already have 'TensorType' instances (DT_UINT32/DT_UINT64)
-- and are plain storable arrays, so the same byte-cast encoding applies.
instance TensorDataType S.Vector Word32 where
    decodeTensorData = simpleDecode
    encodeTensorData = simpleEncode

instance TensorDataType S.Vector Word64 where
    decodeTensorData = simpleDecode
    encodeTensorData = simpleEncode

-- TODO: Haskell and tensorflow use different byte sizes for bools, which makes
-- encoding more expensive. It may make sense to define a custom boolean type.
--
-- Bools are stored as one byte per element: any non-zero byte decodes to
-- 'True', and encoding writes 1 for 'True' and 0 for 'False'.
instance TensorDataType S.Vector Bool where
    decodeTensorData =
        S.convert . S.map (/= 0) . FFI.tensorDataBytes . unTensorData
    encodeTensorData (Shape xs) v
        -- Same length check as 'simpleEncode': a vector that disagrees with
        -- the shape would otherwise yield a malformed tensor.
        | product xs /= fromIntegral (S.length v) = error $ printf
            "encodeTensorData: bad vector length for Bool tensor of shape %v: expected=%d got=%d"
            (show xs) (product xs) (S.length v)
        | otherwise =
            TensorData $ FFI.TensorData xs DT_BOOL $ S.map fromBool v
      where
        fromBool x = if x then 1 else 0 :: Word8

-- | Boxed vectors go through the storable-vector instance for the same
-- element type and are converted with 'S.convert' in each direction.
instance {-# OVERLAPPABLE #-} (Storable a, TensorDataType S.Vector a, TensorType a)
    => TensorDataType V.Vector a where
    decodeTensorData td = S.convert (decodeTensorData td :: S.Vector a)
    encodeTensorData shp vec = encodeTensorData shp (S.convert vec :: S.Vector a)

-- Complex element types are not supported yet: both directions are 'error'
-- stubs.  These instances are OVERLAPPING so they are chosen over the
-- generic 'V.Vector' instance, which would otherwise require an 'S.Vector'
-- instance that this module does not define for these types.
instance {-# OVERLAPPING #-} TensorDataType V.Vector (Complex Float) where
    decodeTensorData = error "TODO (Complex Float)"
    encodeTensorData = error "TODO (Complex Float)"

instance {-# OVERLAPPING #-} TensorDataType V.Vector (Complex Double) where
    decodeTensorData = error "TODO (Complex Double)"
    encodeTensorData = error "TODO (Complex Double)"

instance {-# OVERLAPPING #-} TensorDataType V.Vector ByteString where
    -- Strings can be encoded in various ways, see [0] for an overview.
    --
    -- The data starts with an array of TF_TString structs (24 bytes each), one
    -- for each element in the tensor. In some cases, the actual string
    -- contents are inlined in the TF_TString, in some cases they are in the
    -- heap, in some cases they are appended to the end of the data.
    --
    -- When decoding, we delegate most of those details to the TString C API.
    -- However, when encoding, the TString C API is prone to memory leaks given
    -- the current design of tensorflow-haskell, so, instead we manually encode
    -- all the strings in the "offset" format, where none of the string data is
    -- stored in separate heap objects and so no destructor hook is necessary.
    --
    -- [0] https://github.com/tensorflow/community/blob/master/rfcs/20190411-string-unification.md
    decodeTensorData tensorData
        | S.length bytes < minBytes =
            error $ "Malformed TF_STRING tensor; decodeTensorData for ByteString with too few bytes, got " ++
                    show (S.length bytes) ++ ", need at least " ++ show minBytes
        | otherwise =
            -- One Raw.sizeOfTString-byte chunk per element (the same stride
            -- used for minBytes); any trailing string data is ignored here.
            V.fromList $ map FFI.unsafeTStringToByteString $
                take numElements (chunksOf Raw.sizeOfTString bytes)
      where
        rawData = unTensorData tensorData
        bytes = FFI.tensorDataBytes rawData
        -- Product of the dimensions; 1 for a scalar (empty dimension list).
        numElements = fromIntegral $ product $ FFI.tensorDataDimensions rawData
        minBytes = Raw.sizeOfTString * numElements
    encodeTensorData (Shape xs) vec
        -- Same length check as 'simpleEncode': otherwise the TF_TString table
        -- length would disagree with the element count implied by the shape.
        | product xs /= fromIntegral (V.length vec) = error $ printf
            "encodeTensorData: bad vector length for ByteString tensor of shape %v: expected=%d got=%d"
            (show xs) (product xs) (V.length vec)
        | otherwise = TensorData $ FFI.TensorData xs dt byteVector
      where
        dt = tensorType (undefined :: ByteString)
        -- Size in bytes of the TF_TString table that precedes the string data.
        tableSize = fromIntegral $ Raw.sizeOfTString * (V.length vec)
        -- Add a string to an offset table and data blob.
        -- NOTE(review): sizes and offsets are Word32, so very large strings
        -- (or a very large total payload) would wrap silently; confirm
        -- whether a bound check is wanted.
        addString :: (Builder, Builder, Word32, Word32)
                  -> ByteString
                  -> (Builder, Builder, Word32, Word32)
        addString (table, strings, tableOffset, stringsOffset) str =
            ( table <> Builder.word32LE sizeField
                    <> Builder.word32LE offsetField
                    <> Builder.word32LE capacityField
                    <> Builder.word32LE 0
                    <> Builder.word32LE 0
                    <> Builder.word32LE 0
            , strings <> Builder.byteString str
            , tableOffset + fromIntegral Raw.sizeOfTString
            , stringsOffset + strLen
            )
          where
            strLen :: Word32 = fromIntegral $ B.length str
            -- TF_TString.size includes a union tag in the first two bits.
            sizeField :: Word32 = (shiftL strLen 2) .|. Raw.tstringOffsetTypeTag
            -- offset is relative to the start of the TF_TString instance, so
            -- we add the remaining distance to the end of the table to the
            -- offset from the start of the string data.
            offsetField :: Word32 = tableSize - tableOffset + stringsOffset
            capacityField :: Word32 = strLen
        -- Encode all strings.
        (table', strings', _, _) = V.foldl' addString (mempty, mempty, 0, 0) vec
        -- Concat offset table with data.
        bytes = table' <> strings'
        -- Convert to Vector Word8.
        byteVector = S.fromList $ L.unpack $ Builder.toLazyByteString bytes

-- | Wrapper for a single tensor element.  Its 'TensorDataType' instance
-- encodes the value as a one-element vector, and decoding fails unless the
-- tensor holds exactly one element.
newtype Scalar a = Scalar {unScalar :: a}
    deriving (Show, Eq, Ord, Num, Fractional, Floating, Real, RealFloat,
              RealFrac, IsString)

-- | Scalars go through the boxed-vector instance as a one-element vector.
instance (TensorDataType V.Vector a, TensorType a) => TensorDataType Scalar a where
    decodeTensorData td = Scalar (headFromSingleton (decodeTensorData td))
    encodeTensorData shp (Scalar y) = encodeTensorData shp (V.singleton y)

-- | The only element of a one-element vector.  Calls 'error', reporting the
-- actual length, for any other length.
headFromSingleton :: V.Vector a -> a
headFromSingleton vec =
    case V.length vec of
        1 -> V.head vec
        n -> error $ "Unable to extract singleton from tensor of length "
                     ++ show n


-- | Shape (dimensions) of a tensor, as a list of per-axis sizes.
--
-- TensorFlow supports shapes of unknown rank, which are represented as
-- @Nothing :: Maybe Shape@ in Haskell.
newtype Shape = Shape [Int64] deriving Show

-- | Lets shapes be written as list literals under @OverloadedLists@.
instance IsList Shape where
    type Item Shape = Int64
    fromList dims = Shape dims
    toList (Shape dims) = dims

-- | Lens between a 'TensorShapeProto' and a known-rank 'Shape'.
--
-- Reading a proto with unknown rank calls 'error' (the message includes the
-- proto).  Writing always produces a known-rank proto.
protoShape :: Lens' TensorShapeProto Shape
protoShape = under (adapter toShape fromShape)
  where
    toShape proto = fromMaybe unknownRankError (view protoMaybeShape proto)
      where
        unknownRankError = error $
            "Can't convert TensorShapeProto with unknown rank to Shape: "
            ++ showMessageShort proto
    fromShape shp = defMessage & protoMaybeShape .~ Just shp

-- | Lens between a 'TensorShapeProto' and an optional 'Shape'.
--
-- A proto with 'unknownRank' set reads as 'Nothing'.  Otherwise the sizes of
-- its 'dim' entries, in order, form the 'Shape'.  Writing 'Nothing' yields a
-- proto with only 'unknownRank' set.  Writing @Just s@ yields one whose 'dim'
-- list carries the sizes of @s@.
protoMaybeShape :: Lens' TensorShapeProto (Maybe Shape)
protoMaybeShape = under (adapter fromProto toProto)
  where
    fromProto :: TensorShapeProto -> Maybe Shape
    fromProto p
        | view unknownRank p = Nothing
        | otherwise = Just $ Shape $ p ^.. dim . traverse . size
    toProto :: Maybe Shape -> TensorShapeProto
    toProto Nothing = defMessage & unknownRank .~ True
    toProto (Just (Shape ds)) =
        defMessage & dim .~ [defMessage & size .~ d | d <- ds]


-- | Types that can be read from or written to an 'AttrValue' proto through a
-- lens onto the matching field.
class Attribute a where
    attrLens :: Lens' AttrValue a

instance Attribute Float where
    attrLens = f

instance Attribute ByteString where
    attrLens = s

instance Attribute Int64 where
    attrLens = i

instance Attribute DataType where
    attrLens = type'

instance Attribute TensorProto where
    attrLens = tensor

instance Attribute Bool where
    attrLens = b

-- Goes through 'protoShape', which calls 'error' when reading a shape of
-- unknown rank.
instance Attribute Shape where
    attrLens = shape . protoShape

-- Unknown rank is represented as 'Nothing'.
instance Attribute (Maybe Shape) where
    attrLens = shape . protoMaybeShape

-- TODO(gnezdo): support generating list(Foo) from [Foo].
instance Attribute AttrValue'ListValue where
    attrLens = list

-- List-valued attributes are accessed through the 'list' sub-message.
instance Attribute [DataType] where
    attrLens = list . type'

instance Attribute [Int64] where
    attrLens = list . i

-- | A heterogeneous list type.
--
-- A value of type @ListOf f '[a, b, c]@ holds one value each of types @f a@,
-- @f b@ and @f c@, in that order.  The type-level list records each
-- element's type.
data ListOf f as where
    Nil :: ListOf f '[]
    (:/) :: f a -> ListOf f as -> ListOf f (a ': as)

infixr 5 :/

-- | @All c as@ requires the constraint @c@ for every type in @as@.
type family All f as :: Constraint where
    All f '[] = ()
    All f (a ': as) = (f a, All f as)

-- | Type-level 'map': applies @f@ to every element of the type list @as@.
type family Map f as where
    Map f '[] = '[]
    Map f (a ': as) = f a ': Map f as

-- | Element-wise equality.  Requires 'Eq' for every element type @f a@.
instance All Eq (Map f as) => Eq (ListOf f as) where
    Nil == Nil = True
    (x :/ xs) == (y :/ ys) = x == y && xs == ys
    -- Newer versions of GHC use the GADT to tell that the previous cases are
    -- exhaustive.
#if __GLASGOW_HASKELL__ < 800
    _ == _ = False
#endif

-- | Shows a list as a chain of ':/' applications ending in @Nil@, e.g.
-- @Identity 1 :/ Identity True :/ Nil@.
--
-- Precedences follow the @infixr 5@ fixity of ':/'.  The chain is
-- parenthesised only in contexts binding tighter than 5.  Elements are shown
-- at precedence 6, and the tail at precedence 5, so a right-nested chain needs
-- no inner parentheses.
instance All Show (Map f as) => Show (ListOf f as) where
    showsPrec _ Nil = showString "Nil"
    showsPrec d (x :/ xs) = showParen (d > 5)
                                $ showsPrec 6 x . showString " :/ "
                                    . showsPrec 5 xs

-- | A heterogeneous list whose elements are wrapped only in 'Identity'.
type List = ListOf Identity

-- | Equivalent of ':/' for lists: wraps the new head in 'Identity' and
-- conses it onto the tail.
(/:/) :: a -> List as -> List (a ': as)
x /:/ xs = Identity x :/ xs

infixr 5 /:/

-- | A 'Constraint' specifying the possible choices of a 'TensorType'.
--
-- We implement a 'Constraint' like @OneOf '[Double, Float] a@ by turning the
-- natural representation as a disjunction, i.e.,
--
-- @
--    a == Double || a == Float
-- @
--
-- into a conjunction like
--
-- @
--     a \/= Int32 && a \/= Int64 && a \/= ByteString && ...
-- @
--
-- using an enumeration of all the possible 'TensorType's ('AllTensorTypes').
-- A tuple of constraints is a conjunction, but constraints have no built-in
-- disjunction, which is why the rewrite is needed.
type OneOf ts a
    -- Assert `TensorTypes' ts` to make error messages a little better.
    = (TensorType a, TensorTypes' ts, NoneOf (AllTensorTypes \\ ts) a)

-- | Like 'OneOf', but for a whole list: every type in @as@ must be one of
-- the types in @ts@.
type OneOfs ts as = (TensorTypes as, TensorTypes' ts,
                        NoneOfs (AllTensorTypes \\ ts) as)

-- | @NoneOfs ts as@ requires @NoneOf ts a@ for every @a@ in @as@.
type family NoneOfs ts as :: Constraint where
    NoneOfs ts '[] = ()
    NoneOfs ts (a ': as) = (NoneOf ts a, NoneOfs ts as)

-- | A value-level witness that @a@ is a 'TensorType'.  Pattern matching on
-- 'TensorTypeProxy' brings the @TensorType a@ instance into scope.
data TensorTypeProxy a where
    TensorTypeProxy :: TensorType a => TensorTypeProxy a

-- | A list of 'TensorTypeProxy' witnesses, one per type in @ts@.
type TensorTypeList = ListOf TensorTypeProxy

-- | The 'DataType' of each witness in a 'TensorTypeList', in list order.
fromTensorTypeList :: TensorTypeList ts -> [DataType]
fromTensorTypeList Nil = []
fromTensorTypeList (proxy@TensorTypeProxy :/ rest) =
    dataTypeOf proxy : fromTensorTypeList rest
  where
    -- Matching the GADT constructor recovers the 'TensorType' dictionary.
    dataTypeOf :: forall t . TensorTypeProxy t -> DataType
    dataTypeOf TensorTypeProxy = tensorType (undefined :: t)

-- | The 'DataType's of the type-level list @as@, in order.  The 'Proxy'
-- argument only selects @as@ and is never inspected.
fromTensorTypes :: forall as . TensorTypes as => Proxy as -> [DataType]
fromTensorTypes _ = fromTensorTypeList proxies
  where
    proxies :: TensorTypeList as
    proxies = tensorTypes

-- | Type-level lists whose elements are all 'TensorType's.  'tensorTypes'
-- is a term-level list of witnesses, one per element.
class TensorTypes (ts :: [*]) where
    tensorTypes :: TensorTypeList ts

instance TensorTypes '[] where
    tensorTypes = Nil

-- | A non-empty list qualifies when its head is a 'TensorType' and its tail
-- is itself a 'TensorTypes' list.
instance (TensorType t, TensorTypes ts) => TensorTypes (t ': ts) where
    tensorTypes = TensorTypeProxy :/ tensorTypes

-- | A simpler version of the 'TensorTypes' class, that doesn't run
-- afoul of @-Wsimplifiable-class-constraints@.
--
-- In more detail: the constraint @OneOf '[Double, Float] a@ leads
-- to the constraint @TensorTypes' '[Double, Float]@, as a safety-check
-- to give better error messages.  However, if @TensorTypes'@ were a class,
-- then GHC 8.2.1 would complain with the above warning unless @NoMonoBinds@
-- were enabled.  So instead, we use a separate type family for this purpose.
-- For more details: https://ghc.haskell.org/trac/ghc/ticket/11948
--
-- The equations are tried top to bottom, so the longest prefix pattern
-- must come first.
type family TensorTypes' (ts :: [*]) :: Constraint where
    -- Specialize this type family when `ts` is a long list, to avoid deeply
    -- nested tuples of constraints.  Works around a bug in ghc-8.0:
    -- https://ghc.haskell.org/trac/ghc/ticket/12175
    TensorTypes' (t1 ': t2 ': t3 ': t4 ': ts)
        = (TensorType t1, TensorType t2, TensorType t3, TensorType t4
              , TensorTypes' ts)
    TensorTypes' (t1 ': t2 ': t3 ': ts)
        = (TensorType t1, TensorType t2, TensorType t3, TensorTypes' ts)
    TensorTypes' (t1 ': t2 ': ts)
        = (TensorType t1, TensorType t2, TensorTypes' ts)
    TensorTypes' (t ': ts) = (TensorType t, TensorTypes' ts)
    TensorTypes' '[] = ()

-- | A constraint checking that two types are different.
--
-- When @a@ and @b@ are the same type, this reduces to
-- @TypeError a ~ ExcludedCase@.  That equality can never hold, so GHC
-- reports an error mentioning @a@.  The second equation (reducing to @()@)
-- only fires once GHC can tell the two types apart.
type family a /= b :: Constraint where
    a /= a = TypeError a ~ ExcludedCase
    a /= b = ()

-- | Helper types to produce a reasonable type error message when the Constraint
-- "a /= a" fails.  Neither has constructors; they only appear in types.
-- TODO(judahjacobson): Use ghc-8's CustomTypeErrors for this.
data TypeError a
data ExcludedCase

-- | An enumeration of all valid 'TensorType's.
--
-- NOTE(review): some 'TensorType' instances in this module are not listed,
-- for example 'Word32', 'Word64', 'ResourceHandle', 'Variant' and the
-- complex types.  'OneOf' works by excluding the listed types that are not
-- allowed, so an unlisted type satisfies the 'NoneOf' part of every 'OneOf'
-- constraint.  Confirm this is intended.
type AllTensorTypes =
    -- NOTE: This list should be kept in sync with
    -- TensorFlow.OpGen.dtTypeToHaskell.
    -- TODO: Add support for Complex Float/Double.
    '[ Float
     , Double
     , Int8
     , Int16
     , Int32
     , Int64
     , Word8
     , Word16
     , ByteString
     , Bool
     ]

-- | Removes a type from the given list of types.
--
-- Every occurrence is removed, not just the first.  The remaining elements
-- keep their relative order.
type family Delete a as where
    Delete a '[] = '[]
    Delete a (a ': as) = Delete a as
    Delete a (b ': as) = b ': Delete a as

-- | Takes the difference of two lists of types: each element of @bs@ is
-- removed from @as@ with 'Delete', one at a time.
type family as \\ bs where
    as \\ '[] = as
    as \\ (b ': bs) = Delete b as \\ bs

-- | A constraint that the type @a@ doesn't appear in the type list @ts@.
-- Assumes that @a@ and each of the elements of @ts@ are 'TensorType's.
--
-- It expands to one @a /= t@ constraint per element of @ts@.  The
-- equations are tried top to bottom, so the longest prefix pattern must
-- come first.
type family NoneOf ts a :: Constraint where
    -- Specialize this type family when `ts` is a long list, to avoid deeply
    -- nested tuples of constraints.  Works around a bug in ghc-8.0:
    -- https://ghc.haskell.org/trac/ghc/ticket/12175
    NoneOf (t1 ': t2 ': t3 ': t4 ': ts) a
        = (a /= t1, a /= t2, a /= t3, a /= t4, NoneOf ts a)
    NoneOf (t1 ': t2 ': t3 ': ts) a = (a /= t1, a /= t2, a /= t3, NoneOf ts a)
    NoneOf (t1 ': t2 ': ts) a = (a /= t1, a /= t2, NoneOf ts a)
    NoneOf (t1 ': ts) a = (a /= t1, NoneOf ts a)
    NoneOf '[] a = ()
